add deepseek v2 support #508
Conversation
Nice job! Looking forward to the release of the deepseek-coder-v2 AWQ quantized model.
Excellent job! I'm looking forward to the release of the DeepSeek-Coder-V2 AWQ quantized model.
Quantization works on However, I get an error when trying to run generation. Do you have an example script that works? Perplexity:
An example script is:
It works. There may be some issues when working with integrations such as transformers and accelerate.
Hi, do you have plans to release the deepseek-coder-v2-instruct-awq model?
It's too big and I lack the hardware to do it. I will update if I have enough resources.
I can release a quantized version of the big model once merged.
I had a similar script earlier, but ran into a few errors in transformers. Did you use the
I use transformers==4.41.2. I just found that it didn't work after I deleted
I'll try to fix some of them tomorrow.
Now it should work well.
@TechxGenus Am I reading it correctly that multi-GPU does not work with |
Both work well for all settings now.
@TechxGenus I'm curious, why did you decide to quantize the
I did not quantize
Models like Llama can generate text normally. I haven't tested finetuning, but I think setting it back to model.train() should work fine. This modification is because DeepSeek-V2 uses different computations depending on whether it is in training mode (https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Instruct/blob/main/modeling_deepseek.py#L572), and
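For illustration, a minimal, hypothetical sketch of the train/eval distinction described above (generic helper names, not AutoAWQ internals):

```python
import torch

def run_generation(model, inputs):
    # DeepSeek-V2's modeling code takes a different MoE code path when
    # model.training is True, so inference should happen in eval mode.
    model.eval()
    with torch.no_grad():
        return model.generate(**inputs)

def run_finetune_step(model, batch, optimizer):
    # For finetuning, switch back to training mode so the training-time
    # MoE computation is used again.
    model.train()
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```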
Where can I get the quantize script?

Here is my quantize script:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = '/var/mntpkg/deepseek-coder-v2-instruct'
quant_path = 'deepseek-coder-v2-ins-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
model.quantize(tokenizer, quant_config=quant_config)
# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')
```

I tried to quantize DeepSeek-Coder-V2-Instruct (236B) with the latest code from casper-hansen:main, but it failed several times. Here are some error logs:

```
AWQ:   0%|          | 0/60 [00:02<?, ?it/s]
Traceback (most recent call last):
File "/home/chatgpt/young/awq_quant.py", line 15, in <module>
model.quantize(tokenizer, quant_config=quant_config)
File "/home/chatgpt/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/models/base.py", line 198, in quantize
self.quantizer.quantize()
File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 153, in quantize
module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/models/deepseek_v2.py", line 51, in get_layers_for_scaling
inp=input_feat["self_attn.q_proj"],
KeyError: 'self_attn.q_proj'
```

Another one:

```
  File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 335, in _compute_best_scale
self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 69, in pseudo_quantize_tensor
assert torch.isnan(w).sum() == 0
AssertionError
```
This reverts commit 6b45c95.
I fixed it using #524.
Thanks for your quick reply. After pulling the latest code in #516 and #524, I encountered another problem:

```
  File ".../awq/quantize/quantizer.py", line 398, in _compute_best_clip
    input_feat = input_feat[:, 0 : : input_feat.shape[1] // n_sample_token]
ValueError: slice step cannot be zero
```
This is usually caused by one of the MoE model's experts receiving too few tokens from the calibration data.
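A toy illustration of why that slice fails, with a hypothetical guard rather than the library's actual fix: when an expert sees fewer tokens than n_sample_token, the integer division yields a step of zero.

```python
import torch

n_sample_token = 512
input_feat = torch.randn(1, 37, 1536)  # an expert that only saw 37 calibration tokens

step = input_feat.shape[1] // n_sample_token  # 37 // 512 == 0
# input_feat[:, 0::step] would fail here with "ValueError: slice step cannot be zero".

# Hypothetical workaround for illustration: clamp the step to at least 1
# (in practice, feed more calibration samples or a longer sequence length).
step = max(step, 1)
sampled = input_feat[:, 0::step]
print(sampled.shape)  # torch.Size([1, 37, 1536])
```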
Since my development environment is offline, I manually downloaded pile-val-backup-val.jsonl and modified the dataset loading code.

```
chatgpt@chatgpt:~/young$ wc -l pile-val-backup-val.jsonl
214670 pile-val-backup-val.jsonl
```

awq/utils/calib_data.py:

```python
def load_local_dataset(jsonl_file_path):
    return load_dataset('json', data_files={'validation': jsonl_file_path}, split='validation')

def get_calib_dataset(
    data: Union[str, List[str], List[List[int]]] = "pileval",
    tokenizer=None,
    n_samples=256,
    block_size=512,
    split="train",
    text_column="text",
    local_data_path: str = None
):
    if isinstance(data, str):
        if local_data_path:
            dataset = load_local_dataset(local_data_path)
        elif data == "pileval":
            dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
        else:
            dataset = load_dataset(data, split=split)
        dataset = dataset.shuffle(seed=42)
```

awq/quantize/quantizer.py:

```python
def init_quant(self, n_samples=32, seqlen=512):
    modules = self.awq_model.get_model_layers(self.model)
    samples = get_calib_dataset(
        data=self.calib_data,
        tokenizer=self.tokenizer,
        n_samples=n_samples,
        block_size=seqlen,
        split=self.split,
        text_column=self.text_column,
        local_data_path="/home/chatgpt/young/pile-val-backup-val.jsonl"
    )
    samples = torch.cat(samples, dim=0)
```

I followed your advice decreasing
@Cucunnber you need to specify a group size of 64, as explained by @TechxGenus.

I am beginning testing of the big model on a lot of GPUs, but I cannot promise that my attempts will be successful without more credits for compute. At the moment, this is my choice of dataset (using

```python
def load_openhermes_coding():
    data = load_dataset("alvarobartt/openhermes-preferences-coding", split="train")
    samples = []
    for sample in data:
        responses = [f'{response["role"]}: {response["content"]}' for response in sample["chosen"]]
        samples.append("\n".join(responses))
    return samples

# Quantize
model.quantize(
    tokenizer,
    quant_config=quant_config,
    calib_data=load_openhermes_coding(),
    # n_parallel_calib_samples=32,
    # max_calib_samples=128,
    # max_calib_seq_len=4096
)
```
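For reference, the group-size-64 setting mentioned above would look roughly like this in the quantization config (the other fields mirror the earlier script):

```python
quant_config = {
    "zero_point": True,
    "q_group_size": 64,   # 64 instead of 128, as recommended for DeepSeek-V2 in this thread
    "w_bit": 4,
    "version": "GEMM",
}
```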
Never mind, I also got the same error. I have not seen this error before on any other model, so it will be a little hard to debug what's happening here. If you just want a quantized model, you can turn off clipping or optimize it for DeepSeek-V2 specifically; clipping may need to skip a specific layer.
Potential fix (Claude 3.5 Sonnet's suggestion):
Ah, I should have made it clearer. What needs to be reduced is
I encountered this problem before when I was quantizing a fine-tuned MoE model. It was because some experts processed too few tokens, which caused
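A toy illustration of the failure mode being described, assuming an under-used expert yields all-zero activation statistics (a sketch of the mechanism, not AutoAWQ's actual code):

```python
import torch

# Weights of one expert's linear layer (toy sizes).
w = torch.randn(16, 64)

# If an expert receives (almost) no calibration tokens, its activation
# statistics can be all zeros, so the derived scales are zero.
act_scales = torch.zeros(64)

w_scaled = w * act_scales            # all zeros
w_restored = w_scaled / act_scales   # 0 / 0 -> NaN

# This is the condition that trips `assert torch.isnan(w).sum() == 0`
# in pseudo_quantize_tensor.
print(torch.isnan(w_restored).any())  # tensor(True)
```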
@TechxGenus @casper-hansen Another error occurs.

```
Loading checkpoint shards: 100%|██████████| 55/55 [01:07<00:00,  1.23s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Token indices sequence length is longer than the specified maximum sequence length for this model (80864 > 16384). Running this sequence through the model will result in indexing errors
AWQ: 5%|███████▌ | 3/60 [22:57<7:16:12, 459.16s/it]
Traceback (most recent call last):
File "/home/chatgpt/young/awq_quant.py", line 15, in <module>
model.quantize(tokenizer, quant_config=quant_config)
File "/home/chatgpt/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/models/base.py", line 198, in quantize
self.quantizer.quantize()
File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 156, in quantize
scales_list = [
File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 157, in <listcomp>
self._search_best_scale(self.modules[i], **layer)
File "/home/chatgpt/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 278, in _search_best_scale
best_scales = self._compute_best_scale(
File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 335, in _compute_best_scale
self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 69, in pseudo_quantize_tensor
assert torch.isnan(w).sum() == 0
AssertionError
```
This is a long-standing issue: you cannot quantize models when some of the weights end up as NaN values. I do not have a fix for this at the moment. @TechxGenus, what do you think?
I originally thought this should be fixed by the latest #516. Does this error still occur with the latest code? All possible overflow cases seem to be handled.
I have now merged it to the main branch. I have spun up a machine to test quantization of this model after the latest improvements.
Works on
If you still get this error, you may need #532. I haven't tested this patch thoroughly.
I don't think we need it so far. I have been able to progress quite a bit. However, as you can see from the estimate, this will take quite a long time to finish.
I'll give it a try too. Thanks for your support.
@TechxGenus @casper-hansen

```
AWQ: 100%|██████████| 60/60 [15:32:32<00:00, 932.54s/it]
```

I ran the AWQ model with vLLM. Here is my vLLM Docker launcher script:

```bash
model_path=deepseek-coder-v2-ins-awq
tensor_parallel_size=4
port=9190
sudo docker run --gpus='"device=4,5,6,7"' \
-v ~/.cache/huggingface:/root/.cache/huggingface \
-v $model_path:/model/ \
--env "HUGGING_FACE_HUB_TOKEN=<secret>" \
-p $port:8000 \
--ipc=host \
vllm042_dsv2-v2:latest \
--model /model/ \
--served-model-name cmwCoder \
--kv-cache-dtype fp8 \
--quantization awq \
--trust-remote-code \
--seed 42 \
--max-model-len 8192 \
--tensor-parallel-size $tensor_parallel_size
```

Here is the error log when

```
[rank0]: Traceback (most recent call last):
[rank0]: File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank0]: return _run_code(code, main_globals, None,
[rank0]: File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
[rank0]: exec(code, run_globals)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/openai/api_server.py", line 169, in <module>
[rank0]: engine = AsyncLLMEngine.from_engine_args(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 366, in from_engine_args
[rank0]: engine = cls(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 324, in __init__
[rank0]: self.engine = self._init_engine(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 442, in _init_engine
[rank0]: return engine_class(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 160, in __init__
[rank0]: self.model_executor = executor_class(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/executor/ray_gpu_executor.py", line 300, in __init__
[rank0]: super().__init__(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/executor/executor_base.py", line 41, in __init__
[rank0]: self._init_executor()
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/executor/ray_gpu_executor.py", line 43, in _init_executor
[rank0]: self._init_workers_ray(placement_group)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/executor/ray_gpu_executor.py", line 165, in _init_workers_ray
[rank0]: self._run_workers("load_model",
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/executor/ray_gpu_executor.py", line 234, in _run_workers
[rank0]: driver_worker_output = self.driver_worker.execute_method(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 146, in execute_method
[rank0]: raise e
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 137, in execute_method
[rank0]: return executor(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py", line 117, in load_model
[rank0]: self.model_runner.load_model()
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 156, in load_model
[rank0]: self.model = get_model(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/__init__.py", line 19, in get_model
[rank0]: return loader.load_model(model_config=model_config,
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 222, in load_model
[rank0]: model = _initialize_model(model_config, self.load_config,
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 88, in _initialize_model
[rank0]: return model_class(config=model_config.hf_config,
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 458, in __init__
[rank0]: self.model = DeepseekV2Model(config, quant_config)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 422, in __init__
[rank0]: self.layers = nn.ModuleList([
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 423, in <listcomp>
[rank0]: DeepseekV2DecoderLayer(config,
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 363, in __init__
[rank0]: self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 104, in __init__
[rank0]: self.experts = nn.ModuleList([
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 105, in <listcomp>
[rank0]: DeepseekV2MLP(hidden_size=config.hidden_size,
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 68, in __init__
[rank0]: self.down_proj = RowParallelLinear(intermediate_size,
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/linear.py", line 633, in __init__
[rank0]: self.quant_method.create_weights(self,
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/quantization/awq.py", line 91, in create_weights
[rank0]: raise ValueError(
[rank0]: ValueError: The input size is not aligned with the quantized weight shape. This can be caused by too large tensor parallel size.
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] Error executing method load_model. This might cause deadlock in distributed execution.
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] Traceback (most recent call last):
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 137, in execute_method
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] return executor(*args, **kwargs)
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py", line 117, in load_model
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] self.model_runner.load_model()
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 156, in load_model
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] self.model = get_model(
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/__init__.py", line 19, in get_model
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] return loader.load_model(model_config=model_config,
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 222, in load_model
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] model = _initialize_model(model_config, self.load_config,
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 88, in _initialize_model
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] return model_class(config=model_config.hf_config,
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 458, in __init__
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] self.model = DeepseekV2Model(config, quant_config)
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 422, in __init__
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] self.layers = nn.ModuleList([
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 423, in <listcomp>
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] DeepseekV2DecoderLayer(config,
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 363, in __init__
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config)
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 104, in __init__
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] self.experts = nn.ModuleList([
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 105, in <listcomp>
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] DeepseekV2MLP(hidden_size=config.hidden_size,
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 68, in __init__
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] self.down_proj = RowParallelLinear(intermediate_size,
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/linear.py", line 633, in __init__
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] self.quant_method.create_weights(self,
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/quantization/awq.py", line 91, in create_weights
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] raise ValueError(
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] ValueError: The input size is not aligned with the quantized weight shape. This can be caused by too large tensor parallel size.
(RayWorkerWrapper pid=7806) INFO 07-04 01:55:38 utils.py:132] reading GPU P2P access cache from /root/.config/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json [repeated 6x across cluster]
(RayWorkerWrapper pid=7806) Cache shape torch.Size([163840, 64]) [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] Error executing method load_model. This might cause deadlock in distributed execution. [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] Traceback (most recent call last): [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 137, in execute_method [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] return executor(*args, **kwargs) [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 222, in load_model [repeated 18x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] self.model_runner.load_model() [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] self.model = get_model( [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/__init__.py", line 19, in get_model [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] return loader.load_model(model_config=model_config, [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] model = _initialize_model(model_config, self.load_config, [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 88, in _initialize_model [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] return model_class(config=model_config.hf_config, [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/linear.py", line 633, in __init__ [repeated 36x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] self.model = DeepseekV2Model(config, quant_config) [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] self.layers = nn.ModuleList([ [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 105, in <listcomp> [repeated 12x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] DeepseekV2MLP(hidden_size=config.hidden_size, [repeated 12x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config) [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] self.experts = nn.ModuleList([ [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] self.down_proj = RowParallelLinear(intermediate_size, [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] self.quant_method.create_weights(self, [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/quantization/awq.py", line 91, in create_weights [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] raise ValueError( [repeated 6x across cluster]
(RayWorkerWrapper pid=7150) ERROR 07-04 01:55:39 worker_base.py:145] ValueError: The input size is not aligned with the quantized weight shape. This can be caused by too large tensor parallel size. [repeated 6x across cluster]
```

Here is the error log when

```
[rank0]: Traceback (most recent call last):
[rank0]: File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank0]: return _run_code(code, main_globals, None,
[rank0]: File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
[rank0]: exec(code, run_globals)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/openai/api_server.py", line 169, in <module>
[rank0]: engine = AsyncLLMEngine.from_engine_args(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 366, in from_engine_args
[rank0]: engine = cls(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 324, in __init__
[rank0]: self.engine = self._init_engine(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 442, in _init_engine
[rank0]: return engine_class(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 160, in __init__
[rank0]: self.model_executor = executor_class(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/executor/ray_gpu_executor.py", line 300, in __init__
[rank0]: super().__init__(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/executor/executor_base.py", line 41, in __init__
[rank0]: self._init_executor()
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/executor/ray_gpu_executor.py", line 43, in _init_executor
[rank0]: self._init_workers_ray(placement_group)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/executor/ray_gpu_executor.py", line 165, in _init_workers_ray
[rank0]: self._run_workers("load_model",
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/executor/ray_gpu_executor.py", line 234, in _run_workers
[rank0]: driver_worker_output = self.driver_worker.execute_method(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 146, in execute_method
[rank0]: raise e
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 137, in execute_method
[rank0]: return executor(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py", line 117, in load_model
[rank0]: self.model_runner.load_model()
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 156, in load_model
[rank0]: self.model = get_model(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/__init__.py", line 19, in get_model
[rank0]: return loader.load_model(model_config=model_config,
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 222, in load_model
[rank0]: model = _initialize_model(model_config, self.load_config,
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 88, in _initialize_model
[rank0]: return model_class(config=model_config.hf_config,
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 458, in __init__
[rank0]: self.model = DeepseekV2Model(config, quant_config)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 422, in __init__
[rank0]: self.layers = nn.ModuleList([
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 423, in <listcomp>
[rank0]: DeepseekV2DecoderLayer(config,
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 363, in __init__
[rank0]: self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 112, in __init__
[rank0]: self.pack_params()
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 134, in pack_params
[rank0]: w1.append(expert.gate_up_proj.weight)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1709, in __getattr__
[rank0]: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
[rank0]: AttributeError: 'MergedColumnParallelLinear' object has no attribute 'weight'. Did you mean: 'qweight'?
(RayWorkerWrapper pid=7463) INFO 07-04 01:58:41 utils.py:132] reading GPU P2P access cache from /root/.config/vllm/gpu_p2p_access_cache_for_0,1,2,3.json [repeated 2x across cluster]
(RayWorkerWrapper pid=7463) Cache shape torch.Size([163840, 64]) [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] Error executing method load_model. This might cause deadlock in distributed execution. [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] Traceback (most recent call last): [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 137, in execute_method [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] return executor(*args, **kwargs) [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 222, in load_model [repeated 6x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] self.model_runner.load_model() [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] self.model = get_model( [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/__init__.py", line 19, in get_model [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] return loader.load_model(model_config=model_config, [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] model = _initialize_model(model_config, self.load_config, [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 88, in _initialize_model [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] return model_class(config=model_config.hf_config, [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 112, in __init__ [repeated 8x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] self.model = DeepseekV2Model(config, quant_config) [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] self.layers = nn.ModuleList([ [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 423, in <listcomp> [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] w1.append(expert.gate_up_proj.weight) [repeated 4x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config) [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] self.pack_params() [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 134, in pack_params [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1709, in __getattr__ [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") [repeated 2x across cluster]
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] AttributeError: 'MergedColumnParallelLinear' object has no attribute 'weight' [repeated 2x across cluster]
[rank0]:[W CudaIPCTypes.cpp:16] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]
```

I don't know what leads to the problem, AWQ or vLLM? If you have managed to run the DeepSeek-V2 AWQ model, could you tell me how to do that? Thanks!
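One way to reason about the "input size is not aligned" error, as a back-of-the-envelope sketch under two assumptions: that DeepSeek-V2's moe_intermediate_size is 1536, and that vLLM requires each rank's slice of a row-parallel AWQ layer to be a multiple of the group size.

```python
# Rough alignment check: vLLM splits each expert's down_proj input
# (moe_intermediate_size) across tensor-parallel ranks, and the AWQ layer
# requires that per-rank size to be a multiple of the quantization group size.
moe_intermediate_size = 1536  # assumed value for DeepSeek-V2; check the model's config.json

for group_size in (128, 64):
    for tp in (1, 2, 4, 8):
        per_rank = moe_intermediate_size // tp
        ok = per_rank % group_size == 0
        print(f"group_size={group_size:3d} tp={tp}: per-rank size {per_rank:4d} -> "
              f"{'aligned' if ok else 'misaligned'}")
```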
vLLM has had problems with tensor parallelism on quantized models for a long time now. Please try the script I have put in the documentation and modify it to use
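For reference, a minimal offline-inference sketch along those lines (paths and the parallel size are placeholders, not the exact script from the documentation):

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-coder-v2-ins-awq",  # placeholder path to the quantized checkpoint
    quantization="awq",
    trust_remote_code=True,
    tensor_parallel_size=8,             # placeholder; must keep weight shapes aligned
    max_model_len=8192,
)

outputs = llm.generate(
    ["Write a Python function that reverses a string."],
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```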
I tried to load the deepseek-coder-v2-instruct-awq model using the following code with 8 L40 GPUs:

```python
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(model_id, fuse_layers=True)
```

However, it keeps failing to load successfully. I observed that the GPU memory on gpu0 is slowly increasing. Could you provide some suggestions?
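One thing that may help, as a hedged sketch: depending on the AutoAWQ version, from_quantized forwards device_map/max_memory to accelerate, which lets the weights shard across all GPUs instead of accumulating on gpu0. The path and memory budget below are placeholders.

```python
from awq import AutoAWQForCausalLM

model_id = "deepseek-coder-v2-ins-awq"  # placeholder path

# Spread the quantized weights across all GPUs; fused layers may not be
# supported for this architecture yet, so trying without them is also worth it.
model = AutoAWQForCausalLM.from_quantized(
    model_id,
    fuse_layers=False,
    device_map="auto",
    max_memory={i: "40GiB" for i in range(8)},  # placeholder per-L40 budget
)
```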
Add support for Deepseek-V2.
I wrote it with reference to the DeepSeek and AWQ papers. It works, but there are several potential issues:
In addition, the fused-layer implementation is missing (considering the complexity of implementing new operators, is it necessary to add fused layers for all new architectures?)